{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Quantization-Aware training with range learning\n",
    "\n",
    "This notebook contains a working example of AIMET Quantization-aware training (QAT) with range learning. \n",
    "QAT with range learning is an AIMET feature that adds quantization simulation operations (also called fake quantization ops) to a trained ML model. \n",
    "A standard training pipeline is then used to train or fine-tune the model. \n",
    "The resulting model should show improved accuracy on quantized ML accelerators.\n",
    "\n",
    "The quantization parameters (like encoding min/max, scale, and offset) for activations are computed once initially. During QAT, both the model weights and quantization parameters are jointly updated to minimize the effects of quantization in the forward pass.\n",
    "\n",
    "\n",
    "## Overall flow\n",
    "\n",
    "The example follows these high-level steps:\n",
    "\n",
    "1. Instantiate the example evaluation and training pipeline\n",
    "2. Load the FP32 model and evaluate the model to find the baseline FP32 accuracy\n",
    "3. Create a quantization simulation model (with fake quantization ops) and evaluate the quantized simuation model\n",
    "4. Fine-tune the quantization simulation model and evaluate the fine-tuned simulation model, which should reflect the accuracy on a quantized ML platform\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "Note\n",
    "\n",
    "This notebook does not show state-of-the-art results. For example, it uses a relatively quantization-friendly model (Resnet18). Also, some optimization parameters like number of fine-tuning epochs are chosen to improve execution speed in the notebook.\n",
    "\n",
    "</div>"
   ]
  },
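  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what the fake quantization ops mentioned above do, the sketch below quantizes a tensor to an 8-bit grid and immediately dequantizes it with NumPy. The helper name `fake_quantize`, the asymmetric scheme, and the offset convention are assumptions made for illustration; this is not AIMET's implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def fake_quantize(x, enc_min, enc_max, bitwidth=8):\n",
    "    # Hypothetical helper: quantize to an integer grid, then dequantize.\n",
    "    levels = 2 ** bitwidth - 1\n",
    "    scale = (enc_max - enc_min) / levels\n",
    "    offset = np.round(enc_min / scale)  # grid position of enc_min\n",
    "    q = np.clip(np.round(x / scale) - offset, 0, levels)\n",
    "    return (q + offset) * scale\n",
    "\n",
    "x = np.array([-1.0, -0.25, 0.0, 0.3, 2.0])\n",
    "print(fake_quantize(x, enc_min=-1.0, enc_max=1.0))\n",
    "```\n",
    "\n",
    "Values inside the encoding range move by at most half a quantization step, while values outside it (2.0 here) are clipped to the edge of the range. QAT trains the model to tolerate exactly this rounding and clipping."
   ]
  },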
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Dataset\n",
    "\n",
    "This example does image classification on the ImageNet dataset. If you already have a version of the data set, use that. Otherwise download the data set, for example from https://image-net.org/challenges/LSVRC/2012/index .\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "Note\n",
    "\n",
    "To speed up the execution of this notebook, you can use a reduced subset of the ImageNet dataset. For example: The entire ILSVRC2012 dataset has 1000 classes, 1000 training samples per class and 50 validation samples per class. However, for the purpose of running this notebook, you can reduce the dataset to, say, two samples per class.\n",
    "\n",
    "</div>\n",
    "\n",
    "Edit the cell below to specify the directory where the downloaded ImageNet dataset is saved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_DIR = '/path/to/imagenet_dir'        # Replace this path with a real directory\n",
    "BATCH_SIZE = 128\n",
    "IMAGE_SIZE = (224, 224)"
   ]
  },
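  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to build the reduced subset mentioned in the note above is to copy the first few files of every class folder into a new directory. The sketch below assumes a `train/<class>/<image>` and `val/<class>/<image>` layout; the helper name `make_subset` and all paths are hypothetical.\n",
    "\n",
    "```python\n",
    "import shutil\n",
    "from pathlib import Path\n",
    "\n",
    "def make_subset(src_dir, dst_dir, per_class=2):\n",
    "    # Hypothetical helper: copy the first `per_class` files of each class folder.\n",
    "    src, dst = Path(src_dir), Path(dst_dir)\n",
    "    for class_dir in sorted(p for p in src.iterdir() if p.is_dir()):\n",
    "        target = dst / class_dir.name\n",
    "        target.mkdir(parents=True, exist_ok=True)\n",
    "        files = sorted(f for f in class_dir.iterdir() if f.is_file())\n",
    "        for image in files[:per_class]:\n",
    "            shutil.copy2(image, target / image.name)\n",
    "\n",
    "# Example calls (placeholder paths, not run here):\n",
    "# make_subset('/path/to/imagenet_dir/train', '/path/to/imagenet_subset/train')\n",
    "# make_subset('/path/to/imagenet_dir/val', '/path/to/imagenet_subset/val')\n",
    "```\n",
    "\n",
    "If you create a subset this way, point `DATASET_DIR` at the subset directory instead of the full dataset."
   ]
  },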
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Instantiate the example evaluation and training datasets\n",
    "\n",
    "**Assign the training and validation dataset to dataset_train and dataset_valid respectively.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "dataset_train = dataset_valid = tf.keras.preprocessing.image_dataset_from_directory(\n",
    "    directory=os.path.join(DATASET_DIR, \"train\"),\n",
    "    labels=\"inferred\",\n",
    "    label_mode=\"categorical\",\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    image_size=IMAGE_SIZE\n",
    ")\n",
    "dataset_valid = tf.keras.preprocessing.image_dataset_from_directory(\n",
    "    directory=os.path.join(DATASET_DIR, \"val\"),\n",
    "    labels=\"inferred\",\n",
    "    label_mode=\"categorical\",\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=False,\n",
    "    image_size=IMAGE_SIZE\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load the model and evaluate to get a baseline FP32 accuracy score\n",
    "\n",
    "**2.1 Load a pretrained ResNet50 model from Keras.** \n",
    "\n",
    "You can load any pretrained Keras model instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications.resnet import ResNet50\n",
    "\n",
    "model = ResNet50(weights='imagenet')\n",
    "model.compile(optimizer=\"adam\", loss=\"categorical_crossentropy\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**2.2 Compute the floating point 32-bit (FP32) accuracy of this model using the evaluate() routine.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "model.evaluate(dataset_valid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 3. Create a quantization simulation model and determine quantized accuracy\n",
    "\n",
    "### Fold Batch Normalization layers\n",
    "\n",
    "Before calculating the simulated quantized accuracy using QuantizationSimModel, fold the BatchNormalization (BN) layers into adjacent Convolutional layers. The BN layers that cannot be folded are left as they are.\n",
    "\n",
    "BN folding improves inference performance on quantized runtimes but can degrade accuracy on these platforms. This step simulates this on-target drop in accuracy. \n",
    "\n",
    "The following code calls AIMET to fold the BN layers of a given model. </br>\n",
    "**NOTE: During folding, a new model is returned. Please use the returned model for the rest of the pipeline.**\n",
    "\n",
    "**3.1 Use the following code to call AIMET to fold the BN layers on the model.**\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "Note\n",
    "\n",
    "Folding returns a new model. Use the returned model for the rest of the pipeline.\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from aimet_tensorflow.keras.batch_norm_fold import fold_all_batch_norms\n",
    "\n",
    "_, model = fold_all_batch_norms(model)"
   ]
  },
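  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the folding arithmetic concrete, the sketch below folds an inference-mode batch norm into a preceding dense layer with NumPy and checks that the outputs match. It is a simplified illustration only: the helper name `fold_bn_into_dense` and the toy shapes are assumptions, and a dense layer stands in for a convolution (the same per-output-channel scaling applies to a convolution kernel's last axis). AIMET's `fold_all_batch_norms` operates on the real Keras layers.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def fold_bn_into_dense(w, b, gamma, beta, mean, var, eps=1e-3):\n",
    "    # Hypothetical helper: w has shape (in, out); BN statistics have shape (out,).\n",
    "    factor = gamma / np.sqrt(var + eps)\n",
    "    return w * factor, (b - mean) * factor + beta\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "x = rng.normal(size=(4, 3))\n",
    "w, b = rng.normal(size=(3, 2)), rng.normal(size=2)\n",
    "gamma, beta = rng.normal(size=2), rng.normal(size=2)\n",
    "mean, var = rng.normal(size=2), rng.uniform(0.5, 2.0, size=2)\n",
    "\n",
    "# Reference: dense layer followed by inference-mode batch norm.\n",
    "y_ref = gamma * ((x @ w + b) - mean) / np.sqrt(var + 1e-3) + beta\n",
    "w_f, b_f = fold_bn_into_dense(w, b, gamma, beta, mean, var)\n",
    "print(np.allclose(x @ w_f + b_f, y_ref))\n",
    "```\n",
    "\n",
    "In floating point the folded layer is equivalent to the original pair. The accuracy impact appears only after quantization, because the rescaled weights can have a wider per-channel range."
   ]
  },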
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the Quantization Sim Model\n",
    "\n",
    "**3.2 Use AIMET to create a QuantizationSimModel.**\n",
    "\n",
    " In this step, AIMET inserts fake quantization ops in the model graph and configures them.\n",
    "\n",
    "Key parameters:\n",
    "\n",
    "- Setting **quant_scheme** to \"training_range_learning_with_tf_init\" enables range learning. AIMET sets scale and offset to be trainable so that they are updated during fine-tuning. \n",
    "- Setting **default_output_bw** to 8 performs all activation quantizations in the model using integer 8-bit precision\n",
    "- Setting **default_param_bw** to 8 performs all parameter quantizations in the model using integer 8-bit precision\n",
    "\n",
    "See [QuantizationSimModel in the AIMET API documentation](https://quic.github.io/aimet-pages/AimetDocs/api_docs/torch_quantsim.html#aimet_torch.quantsim.QuantizationSimModel.compute_encodings) for a full explanation of the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from aimet_tensorflow.keras.quantsim import QuantizationSimModel\n",
    "from aimet_common.defs import QuantScheme\n",
    "\n",
    "sim = QuantizationSimModel(model=model,\n",
    "                           quant_scheme=QuantScheme.training_range_learning_with_tf_init,\n",
    "                           rounding_mode=\"nearest\",\n",
    "                           default_output_bw=8,\n",
    "                           default_param_bw=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "AIMET has added quantizer nodes to the model graph, but before the sim model can be used for inference or training, scale and offset quantization parameters must be calculated for each quantizer node by passing unlabeled data samples through the model to collect range statistics. This process is sometimes referred to as calibration. AIMET refers to it as \"computing encodings\".\n",
    "\n",
    "**3.3 Create a routine to pass unlabeled data samples through the model.** \n",
    "\n",
    "The following code is one way to write a routine that passes unlabeled samples through the model to compute encodings. It uses the existing train or validation data loader to extract samples and pass them to the model. Since there is no need to compute loss metrics, it ignores the model output.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.utils import Progbar\n",
    "from tensorflow.keras.applications.resnet import preprocess_input\n",
    "\n",
    "def pass_calibration_data(sim_model, samples):\n",
    "    dataset = dataset_valid\n",
    "    progbar = Progbar(samples)\n",
    "\n",
    "    batch_cntr = 0\n",
    "    for inputs, _ in dataset:\n",
    "        sim_model(preprocess_input(inputs))\n",
    "\n",
    "        batch_cntr += 1\n",
    "        progbar_stat_update = \\\n",
    "            batch_cntr * BATCH_SIZE if (batch_cntr * BATCH_SIZE) < samples else samples\n",
    "        progbar.update(progbar_stat_update)\n",
    "        if (batch_cntr * BATCH_SIZE) > samples:\n",
    "            break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A few notes regarding the data samples:\n",
    "\n",
    "- A very small percentage of the data samples are needed. For example, the training dataset for ImageNet has 1M samples; 500 or 1000 suffice to compute encodings.\n",
    "- The samples should be reasonably well distributed. While it's not necessary to cover all classes, avoid extreme scenarios like using only dark or only light samples. That is, using only pictures captured at night, say, could skew the results.\n",
    "\n",
    "---\n",
    "\n",
    "**3.4 Call AIMET to pass data through the model and compute the quantization encodings.** \n",
    "\n",
    "Encodings here refer to scale and offset quantization parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "sim.compute_encodings(forward_pass_callback=pass_calibration_data,\n",
    "                      forward_pass_callback_args=1000)"
   ]
  },
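  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a simplified picture of what computing encodings involves, the sketch below tracks the minimum and maximum of some activations over a few batches and derives a scale and offset for 8-bit quantization. The plain min/max rule, the offset convention, and the helper name `minmax_encoding` are assumptions for illustration; AIMET's actual range selection may differ.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def minmax_encoding(batches, bitwidth=8):\n",
    "    # Hypothetical helper: derive scale and offset from observed value ranges.\n",
    "    lo = min(float(np.min(b)) for b in batches)\n",
    "    hi = max(float(np.max(b)) for b in batches)\n",
    "    lo, hi = min(lo, 0.0), max(hi, 0.0)  # keep zero inside the range\n",
    "    scale = max(hi - lo, 1e-8) / (2 ** bitwidth - 1)  # guard against an empty range\n",
    "    offset = int(round(lo / scale))\n",
    "    return {'min': lo, 'max': hi, 'scale': scale, 'offset': offset}\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "batches = [rng.normal(size=(8, 16)) for _ in range(4)]\n",
    "print(minmax_encoding(batches))\n",
    "```\n",
    "\n",
    "With range learning, values like these serve only as the starting point: the encoding ranges are then updated together with the weights during fine-tuning."
   ]
  },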
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**3.5 Compile the model**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "sim.compile(optimizer=\"adam\", loss=\"categorical_crossentropy\", metrics=[\"accuracy\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "The QuantizationSim model is now ready to be used for inference or training. \n",
    "\n",
    "**3.6 Pass the model to the evaluation routine to calculate a simulated quantized accuracy score.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "sim.evaluate(dataset_valid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 4. Perform QAT\n",
    "\n",
    "**4.1 To perform quantization aware training (QAT), train the model for a few more epochs (typically 15-20).** \n",
    "\n",
    "As with any training job, hyper-parameters need to be searched for optimal results. Good starting points are to use a learning rate on the same order as the ending learning rate when training the original model, and to drop the learning rate by a factor of 10 every 5 epochs or so.\n",
    "\n",
    "This example trains for only 1 epoch, but you can experiment with the parameters however you like."
   ]
  },
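  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The step-decay starting point described above can be sketched as a plain schedule function. The base rate of 1e-5 is a placeholder rather than a recommendation, and the name `step_decay` is hypothetical. The commented lines show one possible way to hand the schedule to `tf.keras.callbacks.LearningRateScheduler`; they are not part of this notebook's pipeline.\n",
    "\n",
    "```python\n",
    "def step_decay(epoch, base_lr=1e-5, drop=0.1, epochs_per_drop=5):\n",
    "    # Hypothetical schedule: multiply the rate by `drop` every `epochs_per_drop` epochs.\n",
    "    return base_lr * (drop ** (epoch // epochs_per_drop))\n",
    "\n",
    "for epoch in (0, 4, 5, 10):\n",
    "    print(epoch, step_decay(epoch))\n",
    "\n",
    "# Possible use with Keras (not run here). The lambda discards the current rate\n",
    "# that Keras passes in, so it is not mistaken for `base_lr`:\n",
    "# lr_callback = tf.keras.callbacks.LearningRateScheduler(lambda epoch, lr: step_decay(epoch))\n",
    "# sim.fit(dataset_train, epochs=15, validation_data=dataset_valid, callbacks=[lr_callback])\n",
    "```\n",
    "\n",
    "Epochs 0 through 4 use the base rate, epochs 5 through 9 use one tenth of it, and so on."
   ]
  },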
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "quantized_callback = tf.keras.callbacks.TensorBoard(log_dir=\"./log/quantized\")\n",
    "history = sim.fit(dataset_train, epochs=1, validation_data=dataset_valid, callbacks=[quantized_callback])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**4.2 After QAT finishes, run quantization simulation inference against the validation dataset to see improvements in accuracy.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "sim.evaluate(dataset_valid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "Of course, there might be little gain in accuracy after only one epoch of training. Experiment with the hyper-parameters to get better results.\n",
    "\n",
    "## Next steps\n",
    "\n",
    "The next step is to export this model for installation on the target.\n",
    "\n",
    "**Export the model and encodings.**\n",
    "\n",
    "- Export the model with the updated weights but without the fake quant ops. \n",
    "- Export the encodings (scale and offset quantization parameters). AIMET QuantizationSimModel provides an export API for this purpose.\n",
    "\n",
    "The following code performs these exports."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [],
   "source": [
    "sim.compute_encodings(forward_pass_callback=pass_calibration_data,\n",
    "                      forward_pass_callback_args=1000)\n",
    "sim.export('./data', 'model_after_qat')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## For more information\n",
    "\n",
    "See the [AIMET API docs](https://quic.github.io/aimet-pages/AimetDocs/api_docs/index.html) for details about the AIMET APIs and optional parameters.\n",
    "\n",
    "See the [other example notebooks](https://github.com/quic/aimet/tree/develop/Examples/torch/quantization) to learn how to use QAT with other AIMET post-training quantization techniques."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
